import keras.layers as kl
import keras.models as km
import keras.backend as kb
import keras
import tensorflow as tf
import numpy as np
import cv2
import os
from tqdm import tqdm
import random
import istarmap
from imresize_python import imresize
from multiprocessing.pool import Pool
import time

class WarmUpSchedulerPerEpoch(keras.callbacks.Callback):
    """Per-epoch learning-rate schedule: linear warm-up, then linear decay.

    With ``global_step`` counting finished epochs (starting at 0):

    * while ``global_step <= reached_at`` the rate rises linearly from 0 to
      ``learning_rate_max``, but never below ``warmup_min_lr`` (that floor is
      itself capped at ``learning_rate_max``);
    * afterwards it falls linearly from ``learning_rate_max`` at ``reached_at``
      to ``learning_rate_min`` at ``total_steps`` and is held at
      ``learning_rate_min`` from then on.

    The optimizer rate in effect at the end of each epoch is appended to
    ``learning_rates``.
    """

    def __init__(self,
                 learning_rate_max, learning_rate_min,
                 reached_at,
                 total_steps,
                 warmup_min_lr=5e-5):

        super(WarmUpSchedulerPerEpoch, self).__init__()
        self.learning_rate_max = learning_rate_max
        self.learning_rate_min = learning_rate_min
        self.reached_at = reached_at
        self.total_steps = total_steps
        self.warmup_min_lr = warmup_min_lr
        self.global_step = 0
        self.learning_rates = []
        if self.reached_at == self.total_steps:
            # No decay interval: the line below would divide by zero, so fall
            # straight to the minimum rate once warm-up is over.
            self.a = 0.0
            self.b = self.learning_rate_min
        else:
            # Line lr = a * step + b through (reached_at, max) and (total_steps, min).
            self.a = (self.learning_rate_max - self.learning_rate_min) / (self.reached_at - self.total_steps)
            self.b = (self.learning_rate_min * self.reached_at - self.learning_rate_max * self.total_steps) / (
                        self.reached_at - self.total_steps)

    def on_epoch_end(self, epoch, logs=None):
        """Advance the epoch counter and record the learning rate just used."""
        self.global_step = self.global_step + 1
        lr = kb.get_value(self.model.optimizer.lr)
        self.learning_rates.append(lr)

    def on_epoch_begin(self, batch, logs=None):
        """Set the optimizer learning rate for the upcoming epoch.

        ``batch`` receives the first positional argument Keras passes (the
        epoch index per the Keras Callback API); it is unused because the
        schedule is driven by ``global_step``.
        """
        if self.global_step <= self.reached_at:
            # warming up
            if self.reached_at > 0:
                lr = self.learning_rate_max * self.global_step / self.reached_at
            else:
                # reached_at == 0: no warm-up epochs, start at the peak rate.
                lr = self.learning_rate_max
            # Floor keeps the first epochs from training at ~0, but it must
            # never exceed the configured peak rate.
            lr = max(min(self.warmup_min_lr, self.learning_rate_max), lr)
        else:
            # decaying
            lr = self.a * self.global_step + self.b
            lr = max(self.learning_rate_min, lr)

        kb.set_value(self.model.optimizer.lr, lr)
        print('\nEpoch %05d: setting learning rate to %s.' % (self.global_step + 1, lr))

class KerasLite:
    """Adapter giving a ``tf.lite.Interpreter`` a Keras-style ``predict`` method.

    Only the interpreter's first input and first output tensors are used.
    ``output_height`` / ``output_width`` are read from dimensions 1 and 2 of
    the first output's declared shape.
    """

    def __init__(self, interpreter: tf.lite.Interpreter):
        self.interpreter = interpreter
        self.input_details = interpreter.get_input_details()
        self.output_details = interpreter.get_output_details()

        declared_out_shape = self.output_details[0]['shape']
        self.output_height = declared_out_shape[1]
        self.output_width = declared_out_shape[2]
        self.interpreter.allocate_tensors()

    def predict(self, input_tensor: np.array):
        """Run one inference on ``input_tensor`` and return the first output.

        The input slot is resized to ``input_tensor.shape`` and tensors are
        re-allocated on every call, so inputs of different sizes are accepted.
        """
        interp = self.interpreter
        in_slot = self.input_details[0]["index"]
        out_slot = self.output_details[0]["index"]

        interp.resize_tensor_input(in_slot, input_tensor.shape)
        interp.allocate_tensors()
        interp.set_tensor(in_slot, input_tensor)
        interp.invoke()
        return interp.get_tensor(out_slot)

class TrainingDataset(keras.utils.Sequence):
    """Random-patch super-resolution training data, preloaded into memory.

    HR and LR files are paired by position in their sorted directory listings;
    only the first ``percent`` fraction (counted on the HR folder) is used.
    Each batch draws ``batch_size`` random images and, identically for the HR
    and LR image, applies one of eight rotation/flip variants, multiplies the
    intensities by 1 or 0.7, and cuts aligned patches of ``image_h x image_w``
    (HR) and ``image_h/scale x image_w/scale`` (LR) pixels.

    NOTE: ``__getitem__`` refills and returns the same two preallocated
    buffers on every call, so a returned batch is overwritten by the next call.
    """

    # (number of 90-degree rotations, flip axis or None), indexed by int(8 * random.random()).
    _AUGMENTATIONS = (
        (0, None), (1, None), (2, None), (3, None),
        (0, 0), (0, 1), (1, 0), (1, 1),
    )
    # Intensity multipliers, indexed by int(2 * random.random()).
    _COLOR_MIX = (1, 0.7)

    def __init__(self, img_folder_hr, img_folder_lr, batch_size, image_w, image_h, scale, percent = 0.9, epoch_size=1000, use_pools=True):
        """Preload the HR/LR training images.

        Args:
            img_folder_hr: folder with the high-resolution images.
            img_folder_lr: folder with the matching low-resolution images.
            batch_size: patches per batch.
            image_w, image_h: HR patch width/height in pixels.
            scale: HR/LR size ratio; LR patches are ``image_*/scale``.
            percent: fraction of the (sorted) files used for training.
            epoch_size: number of batches reported by ``__len__``.
            use_pools: read files with a multiprocessing pool; set to False if
                pooled reading produces file-reading errors.

        Raises:
            ValueError: if the LR folder has fewer files than the HR selection,
                or no HR image is larger than the patch (sampling would never end).
        """
        self.img_folder_hr = img_folder_hr
        self.img_folder_lr = img_folder_lr
        self.bs = batch_size
        listing_hr = sorted(os.listdir(self.img_folder_hr))
        num_of_files = int(percent * len(listing_hr))
        self.files_hr = listing_hr[:num_of_files]
        self.files_lr = sorted(os.listdir(self.img_folder_lr))[:num_of_files]
        if len(self.files_lr) != len(self.files_hr):
            raise ValueError("LR folder %r has %d usable files but %d HR files were selected"
                             % (self.img_folder_lr, len(self.files_lr), len(self.files_hr)))
        self.num_of_files = len(self.files_hr)
        self.scale = scale
        self.image_w = image_w
        self.image_h = image_h
        self.image_h_lr = int(self.image_h / self.scale)
        self.image_w_lr = int(self.image_w / self.scale)
        self.epoch_size = epoch_size
        self.use_pools = use_pools

        file_names_hr = [self.img_folder_hr + "/" + name for name in self.files_hr]
        file_names_lr = [self.img_folder_lr + "/" + name for name in self.files_lr]
        if self.use_pools:
            # do reading multiprocessed
            with Pool(processes=os.cpu_count()) as pool:
                self.images_hr = self._preload(modcrop_imread_RGB, file_names_hr, "HR", pool)
                self.images_lr = self._preload(imread_RGB_norm, file_names_lr, "LR", pool)
        else:
            self.images_hr = self._preload(modcrop_imread_RGB, file_names_hr, "HR")
            self.images_lr = self._preload(imread_RGB_norm, file_names_lr, "LR")

        # __getitem__ keeps redrawing until it hits an image strictly larger
        # than the patch; without one it would loop forever.
        if not any(img.shape[0] > self.image_h and img.shape[1] > self.image_w
                   for img in self.images_hr):
            raise ValueError("no HR image in %r is larger than the %dx%d patch"
                             % (self.img_folder_hr, self.image_h, self.image_w))

        self.hr_image_patches = np.zeros((self.bs, self.image_h, self.image_w, 3), np.float32)
        self.lr_image_patches = np.zeros((self.bs, self.image_h_lr, self.image_w_lr, 3), np.float32)

    @staticmethod
    def _preload(loader, file_names, label, pool=None):
        """Load every file with ``loader`` (in ``pool`` if given), preserving order."""
        if pool is None:
            results = map(loader, file_names)
            desc = "Preloading %s images for speed" % label
        else:
            results = pool.imap(loader, file_names)
            desc = "(Pool) Preloading %s images for speed" % label
        return list(tqdm(results, desc=desc, total=len(file_names)))

    @staticmethod
    def _augment(image, rotations, flip_axis):
        """Rotate by ``rotations`` x 90 degrees, then flip along ``flip_axis`` if not None."""
        image = np.rot90(image, rotations)
        return image if flip_axis is None else np.flip(image, flip_axis)

    def __len__(self):
        return self.epoch_size

    def __getitem__(self, idx):
        """Fill and return one random batch; ``idx`` is ignored.

        Returns:
            ``([lr_patches, hr_patches, hr_patches], hr_patches)`` built from
            the shared preallocated buffers.
        """
        ix = 0
        while ix < self.bs:
            pick = int(self.num_of_files * random.random())

            img = self.images_hr[pick]
            if img.shape[1] <= self.image_w or img.shape[0] <= self.image_h:
                continue
            img_lr = self.images_lr[pick]

            rotations, flip_axis = self._AUGMENTATIONS[int(8 * random.random())]
            img = self._augment(img, rotations, flip_axis)
            img_lr = self._augment(img_lr, rotations, flip_axis)
            if img.shape[0] <= self.image_h or img.shape[1] <= self.image_w:
                # A 90-degree rotation swaps height and width, so with a
                # non-square patch the rotated image may no longer fit.
                continue

            color_mix = self._COLOR_MIX[int(2 * random.random())]

            # Random crop origin, snapped to a multiple of scale so the LR
            # crop lands on whole pixels.
            height, width = img.shape[0], img.shape[1]
            i = int(((height - self.image_h) * random.random()) / self.scale) * self.scale
            j = int(((width - self.image_w) * random.random()) / self.scale) * self.scale
            i_lr = int(i / self.scale)
            j_lr = int(j / self.scale)

            self.hr_image_patches[ix, :, :, :] = color_mix * img[i:i + self.image_h, j:j + self.image_w, :3]
            self.lr_image_patches[ix, :, :, :] = color_mix * img_lr[i_lr:i_lr + self.image_h_lr, j_lr:j_lr + self.image_w_lr, :3]

            ix += 1

        return [self.lr_image_patches, self.hr_image_patches, self.hr_image_patches], self.hr_image_patches

class ValidationDataset(keras.utils.Sequence):
    """Full-image validation data: LR inputs, HR targets and bicubic baselines.

    Uses the tail of the sorted HR listing starting at
    ``int((1 - percent) * N)``; the LR listing is sliced at the same index and
    paired by position. Item ``k`` is a batch of one:
    ``([lr, hr, bicubic], hr)``, where ``bicubic`` is the LR image passed
    through ``imresize(lr, downscale)``.
    """

    def __init__(self, img_folder_hr, img_folder_lr, downscale, percent=0.1, use_pools=True):
        """Preload validation images and their bicubic upscalings.

        Args:
            img_folder_hr: folder with the high-resolution images.
            img_folder_lr: folder with the matching low-resolution images.
            downscale: factor handed to ``imresize`` for the bicubic baseline.
            percent: fraction of the (sorted) files, taken from the end.
            use_pools: read/resize with a multiprocessing pool; set to False if
                pooled reading produces file-reading errors.

        Raises:
            ValueError: if the selected HR and LR file counts differ.
        """
        self.img_folder_hr = img_folder_hr
        self.img_folder_lr = img_folder_lr
        listing_hr = sorted(os.listdir(self.img_folder_hr))
        first_file = int((1 - percent) * len(listing_hr))
        self.files_hr = listing_hr[first_file:]
        self.files_lr = sorted(os.listdir(self.img_folder_lr))[first_file:]
        if len(self.files_lr) != len(self.files_hr):
            raise ValueError("selected %d LR files but %d HR files; folders %r and %r do not pair up"
                             % (len(self.files_lr), len(self.files_hr),
                                self.img_folder_lr, self.img_folder_hr))
        self.num_of_files = len(self.files_hr)
        self.downscale = downscale
        self.use_pools = use_pools

        file_names_hr = [self.img_folder_hr + "/" + name for name in self.files_hr]
        file_names_lr = [self.img_folder_lr + "/" + name for name in self.files_lr]

        if self.use_pools:
            # do reading multiprocessed
            with Pool(processes=os.cpu_count()) as pool:
                self.images_hr = list(tqdm(pool.imap(modcrop_imread_RGB, file_names_hr),
                                           desc="(Pool) Preloading Validation HR images for speed",
                                           total=len(file_names_hr)))
                self.images_lr = list(tqdm(pool.imap(imread_RGB_norm, file_names_lr),
                                           desc="(Pool) Preloading Validation LR images for speed",
                                           total=len(file_names_lr)))
                params = [(img, self.downscale) for img in self.images_lr]
                # istarmap is attached to Pool by the imported istarmap module.
                self.images_bicubic = list(tqdm(pool.istarmap(imresize, params),
                                                desc="(Pool) Upscaling Validation LR images using bicubic interpolation",
                                                total=len(file_names_lr)))
        else:
            self.images_hr = [modcrop_imread_RGB(name) for name in
                              tqdm(file_names_hr, desc="Preloading Validation HR images for speed")]
            self.images_lr = [imread_RGB_norm(name) for name in
                              tqdm(file_names_lr, desc="Preloading Validation LR images for speed")]
            self.images_bicubic = [imresize(img, self.downscale) for img in
                                   tqdm(self.images_lr, desc="Upscaling Validation LR images using bicubic interpolation")]

    def __len__(self):
        return int(len(self.files_hr))

    def __getitem__(self, idx):
        return [self.images_lr[idx][None,:,:,:], self.images_hr[idx][None,:,:,:], self.images_bicubic[idx][None,:,:,:]], self.images_hr[idx][None,:,:,:]

class ModelSaveOnEpochEnd(keras.callbacks.Callback):
    """Keras callback that saves ``model_to_save`` whenever validation PSNR improves.

    At the end of each epoch it reads ``logs["val_custom_psnr"]`` (presumably
    produced by compiling the training model with ``custom_psnr`` as a metric
    and fitting with validation data). On improvement it saves the model to
    ``filename`` and, when ``write_out`` is set and ``validation_images`` was
    given, runs ``dataset_prediction`` on them.
    """

    def __init__(self, model_to_save, filename, write_out=True, validation_images=None,downscale=None):
        super(ModelSaveOnEpochEnd, self).__init__()
        self.model_to_save = model_to_save
        self.global_step = 0
        self.best_psnr = -np.inf
        self.learning_rates = []
        self.filename = filename
        self.write_out = write_out
        self.validation_images = validation_images
        self.downscale = downscale

    def on_epoch_end(self, epoch, logs=None):
        """Save the model if this epoch's validation PSNR beats the best so far."""
        self.global_step = self.global_step + 1
        logs = logs or {}
        psnr = logs["val_custom_psnr"]
        if psnr > self.best_psnr:
            self.model_to_save.save(self.filename)
            print('\nEpoch %05d: PSNR increased from %f to %f dB' % (self.global_step, self.best_psnr, psnr))
            self.best_psnr = psnr
            # dataset_prediction cannot run without a dataset to predict on.
            if self.write_out and self.validation_images is not None:
                dataset_prediction(self.model_to_save, self.validation_images, self.downscale)
        else:
            print('\nEpoch %05d: PSNR did not improve from %f, current value is %f'
                  % (self.global_step, self.best_psnr, psnr))

def custom_charbon(Xorig, Xblurred_sharp):
    """Per-sample Charbonnier loss: mean of sqrt(diff**2 + 1e-2) over axes 1-3.

    The 1e-2 term keeps the square root differentiable at zero error.
    Reducing over axes 1, 2, 3 presumes 4-D (batch, H, W, C) inputs — TODO confirm.
    """
    residual = Xblurred_sharp - Xorig
    smoothed = kb.sqrt(kb.square(residual) + 1e-2)
    return kb.mean(smoothed, axis=[1, 2, 3])

def custom_mse(Xorig, Xblurred_sharp):
    """Per-sample mean squared error over axes 1-3, plus 1e-8.

    The 1e-8 offset keeps the result strictly positive, which custom_psnr
    relies on when it takes log(1 / mse).
    """
    per_sample = kb.mean(kb.square(Xblurred_sharp - Xorig), axis=[1, 2, 3])
    return per_sample + 1e-8

def custom_psnr(y_true, y_pred):
    """Per-sample PSNR in dB for images with peak value 1.0.

    Computes 10 * log10(1 / mse) using custom_mse, whose +1e-8 offset avoids
    division by zero. log10 is derived from the natural log with the exact
    ln(10), so the value agrees with the np.log10-based PSNR used elsewhere.
    Keras averages the per-sample values when this is used as a metric.
    """
    import math

    ln10 = math.log(10.0)  # Python float, so no dtype promotion of the tensor
    return 10.0 * kb.log(1.0 / custom_mse(y_true, y_pred)) / ln10

def modcrop(im, modulo):
    """Crop ``im`` from the top-left so height and width are multiples of ``modulo``.

    Axes beyond the first two (e.g. colour channels) are kept unchanged.
    """
    rows = im.shape[0]
    cols = im.shape[1]
    keep_rows = (rows // modulo) * modulo
    keep_cols = (cols // modulo) * modulo
    return im[:keep_rows, :keep_cols, ...]

def shave(im, border=2):
    """Trim ``border + 6`` pixels from both ends of the first two axes of ``im``.

    Any further axes (e.g. channels) are kept whole.
    """
    margin = border + 6
    return im[margin:-margin, margin:-margin, ...]

def keras_shave(im, border=2):
    """Trim ``border + 6`` pixels from both ends of axes 1 and 2 of a 4-D batch.

    Axes 0 and 3 are left untouched.
    """
    margin = border + 6
    crop = slice(margin, -margin)
    return im[:, crop, crop, :]

def imread_RGB_norm(filename, float_type=np.float32):
    """Read an image file with OpenCV and return it as RGB scaled to [0, 1].

    Args:
        filename: path of the image to read.
        float_type: dtype of the returned array.

    Returns:
        The image, converted from OpenCV's BGR order to RGB, divided by 255.

    Raises:
        IOError: if OpenCV cannot read ``filename``.
    """
    img = cv2.imread(filename)
    if img is None:
        # cv2.imread signals failure by returning None instead of raising;
        # without this check the error surfaces later as an opaque cvtColor failure.
        raise IOError("Could not read image file: %s" % filename)
    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    img = img.astype(float_type)
    return img / 255.0

def modcrop_imread_RGB(filename, scale=3):
    """Load ``filename`` via imread_RGB_norm and crop it to a multiple of ``scale``."""
    image = imread_RGB_norm(filename)
    return modcrop(image, scale)

def _to_uint8(img):
    """Scale a [0, 1] float image to uint8, saturating (not wrapping) out-of-range values."""
    return np.clip(img * 255, 0, 255).astype(np.uint8)


def _psnr(reference, estimate, peak, border):
    """Return the PSNR in dB of ``estimate`` vs ``reference`` after ``shave(..., border)``."""
    diff = shave(reference, border) - shave(estimate, border)
    return 10 * np.log10((peak * peak) / np.mean(np.square(diff)))


def dataset_prediction(model, validation_images, downscale, use_uint8=False, use_fixed_size=False, write_out=True):
    """Super-resolve every validation sample, report PSNR/FPS, optionally save comparisons.

    Args:
        model: object with a ``predict`` method (e.g. a Keras model or KerasLite).
        validation_images: indexable dataset whose item ``k`` is
            ``([lr, _, bicubic], hr)`` with batch-of-one arrays, as produced by
            ValidationDataset.
        downscale: HR/LR size ratio; also the border passed to ``shave`` for PSNR.
        use_uint8: feed the model 8-bit input and compare 8-bit images
            (peak 255) instead of float images in [0, 1].
        use_fixed_size: only with ``use_uint8``; crop the LR input to its
            top-left 360x640 pixels and HR/bicubic to the matching
            ``downscale``-times larger region.
        write_out: save ``hr | sr | bicubic`` side-by-side PNGs, annotated with
            "model PSNR / bicubic PSNR", as ``./output_images/image_<k>.png``.

    Returns:
        Mean PSNR (dB) of the model output over the dataset.

    Raises:
        ValueError: if ``validation_images`` is empty.
    """
    num_images = len(validation_images)
    if num_images == 0:
        raise ValueError("validation_images is empty; nothing to predict")

    print("\nPredicting!...")

    save_dir = "./output_images/image_"
    if write_out:
        # cv2.imwrite reports a missing directory by returning False, not by raising.
        os.makedirs(os.path.dirname(save_dir), exist_ok=True)

    if use_uint8 and use_fixed_size:
        hr_h, hr_w = int(360 * downscale), int(640 * downscale)

    dur = 0.0
    val_psnr_sum = 0.0
    int_psnr_sum = 0.0
    for img_no in tqdm(range(num_images)):
        inputs, target = validation_images[img_no]
        lr_img = inputs[0]
        bicubic_img = inputs[2][0]
        hr_img = target[0]

        if use_uint8:
            if use_fixed_size:
                lr_img = lr_img[:, :360, :640, :]
            lr_img = _to_uint8(lr_img)

        st = time.perf_counter()
        sr_img = model.predict(lr_img)  # super-resolved batch of one
        if use_uint8:
            sr_img = np.clip(sr_img, 0, 255).astype(np.uint8)[0]
        else:
            sr_img = np.clip(sr_img, 0, 1)[0]
        dur += time.perf_counter() - st

        if use_uint8:
            if use_fixed_size:
                hr_img = hr_img[:hr_h, :hr_w, :]
                bicubic_img = bicubic_img[:hr_h, :hr_w, :]
            hr_img = _to_uint8(hr_img)
            # Clip before casting: bicubic overshoot outside [0, 1] would
            # otherwise wrap around in uint8.
            bicubic_img = _to_uint8(bicubic_img)
            hr_f = hr_img.astype(np.float32)
            current_psnr = _psnr(hr_f, sr_img.astype(np.float32), 255, downscale)
            current_psnr_intp = _psnr(hr_f, bicubic_img.astype(np.float32), 255, downscale)
        else:
            current_psnr = _psnr(hr_img, sr_img, 1, downscale)
            current_psnr_intp = _psnr(hr_img, bicubic_img, 1, downscale)

        val_psnr_sum += current_psnr
        int_psnr_sum += current_psnr_intp

        if write_out:
            panel = cv2.hconcat([hr_img, sr_img, bicubic_img])
            if not use_uint8:
                # Convert to 8-bit explicitly (rounded, saturated) before drawing and writing.
                panel = np.clip(np.rint(255 * panel), 0, 255).astype(np.uint8)
            panel = cv2.cvtColor(panel, cv2.COLOR_RGB2BGR)
            cv2.putText(panel, "{:2.3f}/{:2.3f}".format(current_psnr, current_psnr_intp), (20, 20),
                        cv2.FONT_HERSHEY_PLAIN, 1, (255, 0, 0), 2, cv2.LINE_AA)
            cv2.imwrite(save_dir + str(img_no) + ".png", panel)

    val_psnr_mean = val_psnr_sum / num_images
    print("\nVal PSNR Mean: %f" % val_psnr_mean)
    print("\nIntp PSNR Mean: %f" % (int_psnr_sum / num_images))
    print("FPS: %f" % (num_images / dur))
    return val_psnr_mean

def submission_prediction(model, validation_images_lr_folder, use_uint8=False, save_dir="./submission/", start_index=801):
    """Super-resolve every image in a folder and write the results as PNG files.

    Files are processed in sorted-name order; the k-th result is written to
    ``<save_dir>/<start_index + k>.png`` (``801.png``, ``802.png``, ... by default).

    Args:
        model: object with a ``predict`` method (e.g. a Keras model or KerasLite).
        validation_images_lr_folder: folder holding the LR input images.
        use_uint8: feed the model 8-bit input cropped to the top-left 360x640
            pixels and write its output as returned; otherwise feed [0, 1]
            floats and write the clipped output scaled to 8-bit.
        save_dir: output folder; created if missing.
        start_index: number used in the first output file name.
    """
    print("\nPredicting!...")

    files_lr = sorted(os.listdir(validation_images_lr_folder))
    # cv2.imwrite signals a missing folder by returning False, not by raising.
    os.makedirs(save_dir, exist_ok=True)

    dur = 0.0
    for offset, file_name in enumerate(tqdm(files_lr)):
        lr_img = imread_RGB_norm(os.path.join(validation_images_lr_folder, file_name))
        lr_img = lr_img[None, :, :, :]

        if use_uint8:
            lr_img = np.uint8(lr_img[:, :360, :640, :] * 255)

        st = time.perf_counter()
        sr_img = model.predict(lr_img)  # super-resolved batch of one
        if use_uint8:
            im_h = sr_img[0]
        else:
            # Round to uint8 explicitly; values are already clipped to [0, 1].
            im_h = np.rint(255 * np.clip(sr_img[0], 0, 1)).astype(np.uint8)
        dur += time.perf_counter() - st

        im_h = cv2.cvtColor(im_h, cv2.COLOR_RGB2BGR)
        cv2.imwrite(os.path.join(save_dir, str(start_index + offset) + ".png"), im_h)

    if files_lr:
        print("FPS: %f" % (len(files_lr) / dur))
